edemana commited on
Commit
672e4c0
·
verified ·
1 Parent(s): d51fbd3

Upload model.py

Browse files
Files changed (1) hide show
  1. model.py +294 -0
model.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """Copy NGANof .ipynb
3
+
4
+ Automatically generated by Colab.
5
+
6
+ Original file is located at
7
+ https://colab.research.google.com/drive/1w2sRg7uNq-lx67zg9Afcr58f2P_R-5ca
8
+ """
9
+
10
+ !pip install ninja
11
+ !sudo apt-get update
12
+ !sudo apt-get install build-essential
13
+ !rm -rf ~/.cache/torch_extensions/
14
+
15
+ !pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu121
16
+
17
+ !pip install requests tqdm
18
+
19
+ # Commented out IPython magic to ensure Python compatibility.
20
+ !git clone https://github.com/NVlabs/stylegan2-ada-pytorch.git
21
+ # %cd stylegan2-ada-pytorch
22
+
23
+ !wget https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl
24
+
25
+ from google.colab import drive
26
+ drive.mount('/content/drive')
27
+
28
+ # Import necessary libraries
29
+ import torch
30
+ from torch.utils.data import Dataset, DataLoader
31
+ from torchvision import transforms
32
+ import os
33
+ import PIL.Image
34
+ import numpy as np
35
+ import dnnlib
36
+ import legacy
37
+
38
+ # Load pre-trained model
39
+ import pickle
40
+
41
+ with open('ffhq.pkl', 'rb') as f:
42
+ data = pickle.load(f)
43
+
44
+ print(data.keys())
45
+
46
+ # Check if CUDA is available, and if so, move the model to GPU
47
+ if torch.cuda.is_available():
48
+ G = data['G_ema'].cuda()
49
+ D = data['D'].cuda()
50
+ print("Model loaded on GPU.")
51
+ else:
52
+ G = data['G_ema'] # Keep the model on CPU
53
+ D = data['D'].cuda()
54
+ print("CUDA not available, model loaded on CPU.")
55
+
56
+ print(type(G))
57
+ print(G)
58
+
59
+ print(type(D))
60
+ print(D)
61
+
62
+ def generate_images(G, z=None, num_images=1, truncation_psi=0.7, seed=None):
63
+ if seed is not None:
64
+ torch.manual_seed(seed)
65
+
66
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
67
+ if torch.cuda.is_available():
68
+ print("CUDA is available. Using GPU.")
69
+ else:
70
+ print("CUDA is not available. Using CPU.")
71
+
72
+ if z is None:
73
+ z = torch.randn((num_images, G.z_dim), device=device)
74
+ else:
75
+ z = z.to(device)
76
+
77
+ print("Latent vectors prepared.")
78
+
79
+ ws = G.mapping(z, None, truncation_psi=truncation_psi)
80
+ print("Mapping done.")
81
+
82
+ img = G.synthesis(ws, noise_mode='const')
83
+ print("Synthesis done.")
84
+
85
+ img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).cpu().numpy()
86
+ print("Image tensor converted to numpy array.")
87
+
88
+ return [Image.fromarray(i) for i in img]
89
+
90
+ from PIL import Image # Add this import
91
+
92
+ # Generate 4 images
93
+ generated_images = generate_images(G, num_images=4, truncation_psi=0.7, seed=42)
94
+
95
+ # Generate and display 4 images
96
+ #generated_images = generate_images(G, num_images=4, truncation_psi=0.7, seed=42)
97
+ print("Images generated.")
98
+ for i, img in enumerate(generated_images):
99
+ display(img)
100
+
101
+ # Fine-tuning setup
102
+ import os
103
+ from torch.utils.data import Dataset, DataLoader
104
+ from torchvision import transforms
105
+
106
+ # Advanced fine-tuning setup
107
+ import torch
108
+ from torch.optim import Adam
109
+ from torch.utils.data import Dataset, DataLoader
110
+ from torchvision import transforms
111
+ from torchvision.utils import save_image
112
+
113
+ # Custom dataset
114
+ class CustomDataset(Dataset):
115
+ def __init__(self, image_dir, transform=None):
116
+ self.image_dir = image_dir
117
+ self.image_files = [f for f in os.listdir(image_dir) if f.endswith('.jpg') or f.endswith('.png')]
118
+
119
+ if len(self.image_files) == 0:
120
+ raise ValueError(f"No image files found in directory: {image_dir}")
121
+
122
+ self.transform = transform
123
+
124
+ def __len__(self):
125
+ return len(self.image_files)
126
+
127
+ def __getitem__(self, idx):
128
+ img_path = os.path.join(self.image_dir, self.image_files[idx])
129
+ image = PIL.Image.open(img_path).convert('RGB')
130
+ if self.transform:
131
+ image = self.transform(image)
132
+ return image
133
+
134
+ !pip install datasets
135
+ from datasets import load_dataset
136
+ ds = load_dataset("TrainingDataPro/black-people-liveness-detection-video-dataset")
137
+
138
+ # Set up data loading
139
+ transform = transforms.Compose([
140
+ transforms.Resize((G.img_resolution, G.img_resolution)),
141
+ transforms.ToTensor(),
142
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
143
+ ])
144
+
145
+ dataset = CustomDataset("/content/drive/MyDrive/Colab Notebooks/part2", transform=transform)
146
+ dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
147
+
148
+ # Fine-tuning loop (advanced)
149
+ optimizer_g = Adam(G.parameters(), lr=0.0001, betas=(0, 0.99), eps=1e-8)
150
+ optimizer_d = Adam(D.parameters(), lr=0.0001, betas=(0, 0.99), eps=1e-8)
151
+ scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.99)
152
+ scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optimizer_d, gamma=0.99)
153
+
154
+ num_epochs = 100
155
+ for epoch in range(num_epochs):
156
+ for batch in dataloader:
157
+ real_images = batch.cuda()
158
+
159
+ # Generate fake images
160
+ z = torch.randn([batch.shape[0], G.z_dim]).cuda()
161
+ fake_images = G(z, None) # Replace None with actual labels if available
162
+
163
+ # Ensure fake_images requires gradients
164
+ fake_images.requires_grad_(True)
165
+
166
+ # Prepare a dummy label (replace with actual labels if available)
167
+ c = None # Replace 'some_dimension' with the correct size
168
+
169
+ # Compute loss (using BCE loss)
170
+ g_loss = torch.mean(torch.log(1 - D(fake_images, c))) # Pass the dummy label to D
171
+ d_loss_real = torch.mean(torch.log(D(real_images, c))) # Pass the dummy label to D
172
+ d_loss_fake = torch.mean(torch.log(1 - D(fake_images, c))) # Pass the dummy label to D
173
+ d_loss = -d_loss_real - d_loss_fake
174
+
175
+ # Check for NaNs
176
+ if torch.isnan(g_loss) or torch.isnan(d_loss):
177
+ #print("NaN detected in loss. Skipping update.")
178
+ continue
179
+
180
+ # Update generator
181
+ optimizer_g.zero_grad()
182
+ g_loss.backward()
183
+ torch.nn.utils.clip_grad_norm_(G.parameters(), max_norm=1) # Clip gradients
184
+ optimizer_g.step()
185
+
186
+ # Update discriminator
187
+ optimizer_d.zero_grad()
188
+ d_loss.backward()
189
+ torch.nn.utils.clip_grad_norm_(D.parameters(), max_norm=1) # Clip gradients
190
+ optimizer_d.step()
191
+
192
+ # Update learning rate
193
+ scheduler_g.step()
194
+ scheduler_d.step()
195
+
196
+ print(f"Epoch {epoch+1}/{num_epochs}, G Loss: {g_loss.item()}")
197
+
198
+ # Save the full model
199
+ torch.save(G, '/content/drive/MyDrive/full_model_stylegan.pt')
200
+
201
+ for param in G.parameters():
202
+ param.requires_grad = True
203
+
204
+ # Generate new images with fine-tuned model
205
+ z = torch.randn([4, G.z_dim]).cuda() # Generate 4 random latent vectors
206
+ imgs = generate_images(G, z, truncation_psi=0.7)
207
+
208
+ # Display each generated image
209
+ for img in imgs:
210
+ display(img)
211
+
212
+ # Save the fine-tuned model
213
+ torch.save(G.state_dict(), 'fine_tuned_stylegan.pth')
214
+
215
+ !pip install gradio
216
+
217
+ import torch
218
+ import torchvision.transforms as transforms
219
+ from PIL import Image
220
+ import gradio as gr
221
+ from tqdm import tqdm
222
+
223
+ def optimize_latent_vector(G, target_image, num_iterations=1000):
224
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
225
+ target_image = transforms.Resize((G.img_resolution, G.img_resolution))(target_image)
226
+ target_tensor = transforms.ToTensor()(target_image).unsqueeze(0).to(device)
227
+ target_tensor = (target_tensor * 2) - 1 # Normalize to [-1, 1]
228
+
229
+ latent_vector = torch.randn((1, G.z_dim), device=device, requires_grad=True)
230
+ optimizer = torch.optim.Adam([latent_vector], lr=0.1)
231
+
232
+ for i in tqdm(range(num_iterations), desc="Optimizing latent vector"):
233
+ optimizer.zero_grad()
234
+
235
+ generated_image = G(latent_vector, None)
236
+ loss = torch.nn.functional.mse_loss(generated_image, target_tensor)
237
+
238
+ loss.backward()
239
+ optimizer.step()
240
+
241
+ if (i + 1) % 100 == 0:
242
+ print(f'Iteration {i+1}/{num_iterations}, Loss: {loss.item()}')
243
+
244
+ return latent_vector.detach()
245
+
246
+ def generate_from_upload(uploaded_image):
247
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
248
+
249
+ # Optimize latent vector for the uploaded image
250
+ optimized_z = optimize_latent_vector(G, uploaded_image)
251
+
252
+ # Generate variations
253
+ num_variations = 4
254
+ variation_strength = 0.1
255
+ varied_z = optimized_z + torch.randn((num_variations, G.z_dim), device=device) * variation_strength
256
+
257
+ # Generate the variations
258
+ with torch.no_grad():
259
+ imgs = G(varied_z, c=None, truncation_psi=0.7, noise_mode='const')
260
+
261
+ imgs = (imgs * 127.5 + 128).clamp(0, 255).to(torch.uint8)
262
+ imgs = imgs.permute(0, 2, 3, 1).cpu().numpy()
263
+
264
+ # Convert the generated image tensors to PIL Images
265
+ generated_images = [Image.fromarray(img) for img in imgs]
266
+
267
+ # Return the images separately
268
+ return generated_images[0], generated_images[1], generated_images[2], generated_images[3]
269
+
270
+ # Create the Gradio interface
271
+ iface = gr.Interface(
272
+ fn=generate_from_upload,
273
+ inputs=gr.Image(type="pil"),
274
+ outputs=[gr.Image(type="pil") for _ in range(4)],
275
+ title="StyleGAN Image Variation Generator"
276
+ )
277
+
278
+ # Launch the Gradio interface
279
+ iface.launch(share=True, debug=True)
280
+
281
+ # If you want to test it without the Gradio interface:
282
+ """
283
+ # Load an image explicitly
284
+ image_path = "path/to/your/image.jpg"
285
+ image = Image.open(image_path)
286
+
287
+ # Call the generate method explicitly
288
+ generated_images = generate_from_upload(image)
289
+
290
+ # Display the generated images
291
+ for img in generated_images:
292
+ img.show()
293
+ """
294
+