K00B404 committed on
Commit
454a3ac
1 Parent(s): ea2ff8a

Update app.py

Files changed (1)
  app.py +313 -0
app.py CHANGED
@@ -0,0 +1,313 @@
+ import os
+ import torch
+ import torch.nn as nn
+ import numpy as np
+ import torch.optim as optim
+ import torchvision.transforms as transforms
+ from torch.utils.data import DataLoader, Dataset, Subset
+ from PIL import Image
+ from datasets import load_dataset
+ from transformers import T5EncoderModel, T5Tokenizer, DistilBertModel, DistilBertTokenizer
+ import matplotlib.pyplot as plt
+ from rich import print as rp
+ from credits import HUGGINGFACE_TOKEN, HUGGINGFACE_W_TOKEN, WANDB_API_KEY
+ import wandb  # Experiment tracking
+ import torchvision.utils as vutils  # To save image grids
+
+ write_token = HUGGINGFACE_W_TOKEN
+ read_token = HUGGINGFACE_TOKEN
+
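+ # NOTE (assumption): credits.py is a local, untracked module that must define the three
+ # constants imported above. A minimal sketch of what it is expected to contain:
+ #
+ #   HUGGINGFACE_TOKEN = "hf_..."    # read-scoped Hugging Face token
+ #   HUGGINGFACE_W_TOKEN = "hf_..."  # write-scoped token, used for push_to_hub
+ #   WANDB_API_KEY = "..."           # Weights & Biases API key
+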
+ class TextEncoder(nn.Module):
+     def __init__(self, encoder_model_name, encoder_type="t5", device='cpu'):
+         super(TextEncoder, self).__init__()
+         self.device = device
+         self.encoder_type = encoder_type
+         if encoder_type == "t5":
+             self.tokenizer = T5Tokenizer.from_pretrained(encoder_model_name, cache_dir='./models', token=read_token)
+             self.encoder = T5EncoderModel.from_pretrained(encoder_model_name, cache_dir='./models', token=read_token)
+         elif encoder_type == "distilbert":
+             self.tokenizer = DistilBertTokenizer.from_pretrained(encoder_model_name, cache_dir='./models', token=read_token)
+             self.encoder = DistilBertModel.from_pretrained(encoder_model_name, cache_dir='./models', token=read_token)
+         else:
+             raise ValueError(f"Invalid encoder_type: {encoder_type}. Choose 't5' or 'distilbert'.")
+         self.encoder.to(self.device)
+
+     def encode_text(self, text):
+         if isinstance(text, str):
+             text = [text]
+
+         embeddings = []
+         for t in text:
+             inputs = self.tokenizer(t, return_tensors="pt", padding=True, truncation=True).to(self.device)
+             outputs = self.encoder(**inputs)
+             embeddings.append(outputs.last_hidden_state[:, 0, :])  # first-token embedding, shape (1, hidden)
+
+         # cat (not stack) so the result is a flat (batch, hidden) tensor
+         return torch.cat(embeddings, dim=0)
+
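+ # Usage sketch (illustrative, not part of the training flow): with the t5-small
+ # encoder configured below (hidden size 512), encode_text returns one row per prompt.
+ #
+ #   enc = TextEncoder("google-t5/t5-small", encoder_type="t5", device="cpu")
+ #   emb = enc.encode_text(["A big ape.", "A yellow banana."])
+ #   print(emb.shape)  # torch.Size([2, 512])
+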
+ class ConditionalDiffusionModel(nn.Module):
+     # Despite the name, this is a plain MLP stand-in: it maps a text embedding to a
+     # small conditioning vector and performs no actual diffusion steps.
+     def __init__(self):
+         super(ConditionalDiffusionModel, self).__init__()
+         self.model = nn.Sequential(
+             nn.Linear(512, 768),  # 512 = t5-small hidden size
+             nn.ReLU(),
+             nn.Linear(768, 64),
+             nn.ReLU(),
+             nn.Linear(64, 64)
+         )
+
+     def forward(self, text_embeddings):
+         return self.model(text_embeddings)  # (batch, 512) -> (batch, 64)
+
+ class SuperResolutionDiffusionModel(nn.Module):
+     # A small conv stack; kernel 3 with padding 1 keeps H and W unchanged, so this
+     # refines an image in place rather than increasing its resolution.
+     def __init__(self):
+         super(SuperResolutionDiffusionModel, self).__init__()
+         self.model = nn.Sequential(
+             nn.Conv2d(3, 64, kernel_size=3, padding=1),  # 3 = number of color channels
+             nn.ReLU(),
+             nn.Conv2d(64, 64, kernel_size=3, padding=1),
+             nn.ReLU(),
+             nn.Conv2d(64, 3, kernel_size=3, padding=1)
+         )
+
+     def forward(self, input_image):
+         return self.model(input_image)
+
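+ # Shape sketch (illustrative): resolution is preserved through the stack.
+ #
+ #   sr = SuperResolutionDiffusionModel()
+ #   x = torch.rand(1, 3, 128, 128)
+ #   print(sr(x).shape)  # torch.Size([1, 3, 128, 128]), same spatial size in and out
+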
+ class TextToImageModel(nn.Module):
+     def __init__(self, text_encoder, conditional_diffusion_model, super_resolution_diffusion_model):
+         super(TextToImageModel, self).__init__()
+         self.text_encoder = text_encoder
+         self.conditional_diffusion_model = conditional_diffusion_model
+         self.super_resolution_diffusion_model = super_resolution_diffusion_model
+
+     def forward(self, text):
+         text_embeddings = self.text_encoder.encode_text(text)
+         image_embeddings = self.conditional_diffusion_model(text_embeddings)
+         # NOTE: image_embeddings are computed but not yet wired into the image path;
+         # generation currently starts from unconditioned random noise.
+         input_image = torch.rand((image_embeddings.shape[0], 3, 128, 128)).to(text_embeddings.device)
+         for _ in range(6):  # Apply the refinement stack 6 times (resolution stays 128x128)
+             input_image = self.super_resolution_diffusion_model(input_image)
+         return input_image
+
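+ # End-to-end sketch (illustrative): one prompt in, one 128x128 image tensor out.
+ #
+ #   model = TextToImageModel(TextEncoder("google-t5/t5-small"),
+ #                            ConditionalDiffusionModel(),
+ #                            SuperResolutionDiffusionModel())
+ #   out = model("A house next to a river.")
+ #   print(out.shape)  # torch.Size([1, 3, 128, 128])
+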
+ class CustomDataset(Dataset):
+     def __init__(self, annotations_file, train_img_dir, size_sqr=128):
+         with open(annotations_file, 'r') as f:
+             lines = f.readlines()
+         self.transform = transforms.Compose([
+             transforms.Resize((size_sqr, size_sqr)),
+             transforms.ToTensor(),
+         ])
+         # Each line holds "<filename> <caption>"; split on the first space only
+         self.img_labels = [line.strip().split(' ', 1) for line in lines]
+         self.train_img_dir = train_img_dir
+
+     def __len__(self):
+         return len(self.img_labels)
+
+     def __getitem__(self, idx):
+         img_name, text = self.img_labels[idx]
+         img_path = os.path.join(self.train_img_dir, img_name)
+         image = Image.open(img_path).convert("RGB")
+         if self.transform:
+             image = self.transform(image)
+         return text, image
+
+
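+ # Example annotations.txt layout (hypothetical filenames, matching the
+ # split(' ', 1) parsing above: filename first, then the caption):
+ #
+ #   img_0001.jpg A house next to a river.
+ #   img_0002.jpg A yellow banana on a table.
+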
+ class HuggingDataset(Dataset):  # Wraps a Hugging Face Hub dataset
+     def __init__(self, dataset_name="vera365/lexica_dataset", size_sqr=128, limit=None):
+         # Select the train split up front so Subset and integer indexing work uniformly
+         self.dataset = load_dataset(dataset_name, token=read_token, cache_dir='./datasets')["train"]
+         self.transform = transforms.Compose([
+             transforms.Resize((size_sqr, size_sqr)),
+             transforms.ToTensor(),
+         ])
+         # Apply limit if specified
+         if limit is not None:
+             self.dataset = Subset(self.dataset, range(limit))
+
+     def __len__(self):
+         return len(self.dataset)
+
+     def __getitem__(self, idx):
+         item = self.dataset[idx]
+         image = item["image"].convert("RGB")
+         text = item["subject"]
+         if self.transform:
+             image = self.transform(image)
+         return text, image
+
+
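+ # Usage sketch (illustrative): the limit keeps smoke tests cheap.
+ #
+ #   ds = HuggingDataset(size_sqr=128, limit=8)
+ #   text, image = ds[0]
+ #   print(len(ds), image.shape)  # 8 torch.Size([3, 128, 128])
+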
+ class StorageHandler:
+     def __init__(self, storage_dir="./image_gen_storage", hub_model_name="K00B404/tiny_image_gen", push_dataset=False, dataset_name="K00B404/custom_image_descriptions_dataset"):
+         self.model_name = hub_model_name
+         self.dataset_name = dataset_name
+         self.should_push_dataset = push_dataset  # flag only; the actual push happens in push_dataset_to_hub
+         self.storage_dir = storage_dir
+
+     def save_checkpoint(self, model, optimizer, scheduler, epoch, checkpoint_path):
+         checkpoint = {
+             'model_state_dict': model.state_dict(),
+             'optimizer_state_dict': optimizer.state_dict(),
+             'scheduler_state_dict': scheduler.state_dict(),
+             'epoch': epoch
+         }
+         torch.save(checkpoint, checkpoint_path)
+
+     def load_checkpoint(self, model, optimizer, scheduler, checkpoint_path):
+         if os.path.isfile(checkpoint_path):
+             checkpoint = torch.load(checkpoint_path)
+             model.load_state_dict(checkpoint['model_state_dict'])
+             optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+             scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
+             epoch = checkpoint['epoch']
+             return epoch, scheduler
+         else:
+             return 0, scheduler
+
+     def push_dataset_to_hub(self, dataset):
+         if dataset:
+             dataset.push_to_hub(self.dataset_name, token=write_token)
+
+     def push(self, model, tokenizer):
+         # Assumes model and tokenizer expose push_to_hub (true for transformers models
+         # and tokenizers; a plain nn.Module would need e.g. huggingface_hub.PyTorchModelHubMixin)
+         model.push_to_hub(self.model_name, token=write_token)
+         tokenizer.push_to_hub(self.model_name, token=write_token)
+         # Optionally push dataset to Hugging Face Hub
+
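+ # Usage sketch (illustrative): checkpoints round-trip through save/load.
+ #
+ #   storage = StorageHandler('./models/image_gen')
+ #   storage.save_checkpoint(model, optimizer, scheduler, epoch=3, checkpoint_path='ckpt.pth')
+ #   start_epoch, scheduler = storage.load_checkpoint(model, optimizer, scheduler, 'ckpt.pth')
+ #   # start_epoch == 3; returns (0, scheduler) when the file does not exist
+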
+ class Common:
+     def __init__(self, device='cpu', wandb_log=False):
+         self.wandb_log = wandb_log
+         self.device = device
+         self.terminal_log = rp
+
+         if self.wandb_log:
+             # Initialize wandb; config values are read from the module-level
+             # globals defined under __main__ (fragile, but works when run as a script)
+             wandb.login(key=WANDB_API_KEY)
+             self.wandb = wandb.init(project="my-image-generation-project",
+                                     config={
+                                         "learning_rate": learning_rate,
+                                         "batch_size": batch_size,
+                                         "num_epochs": num_epochs,
+                                         "encoder_model": encoder
+                                     })
+
+
+     def train(self, model, optimizer, scheduler, dataloader, criterion, num_epochs, start_epoch, checkpoint_path):
+         for epoch in range(start_epoch, num_epochs):
+             model.train()
+             for i, (text_batch, image_batch) in enumerate(dataloader):
+                 image_batch = image_batch.to(self.device)
+
+                 optimizer.zero_grad()
+                 images = model(text_batch)
+                 loss = criterion(images, image_batch)
+                 loss.backward()
+                 optimizer.step()
+
+                 if self.wandb_log:
+                     # Log loss and learning rate
+                     self.wandb.log({"train_loss": loss.item(), "lr": optimizer.param_groups[0]['lr']})
+
+                 if i % 25 == 0:
+                     # Save a grid of real and generated images for monitoring;
+                     # compare against the model's actual output, detached from the graph
+                     img_grid_real = vutils.make_grid(image_batch[:4], padding=2, normalize=True)
+                     img_grid_fake = vutils.make_grid(images[:4].detach(), padding=2, normalize=True)
+
+                     plt.figure(figsize=(15, 15))
+                     plt.subplot(1, 2, 1)
+                     plt.axis("off")
+                     plt.title("Real Images")
+                     plt.imshow(np.transpose(img_grid_real.cpu(), (1, 2, 0)))
+
+                     plt.subplot(1, 2, 2)
+                     plt.axis("off")
+                     plt.title("Generated Images")
+                     plt.imshow(np.transpose(img_grid_fake.cpu(), (1, 2, 0)))
+                     plt.savefig(f'generated_images_epoch_{epoch+1}_batch_{i}.png')
+                     plt.close()
+
+             # Validation step (reuses the training dataloader; a held-out split would be preferable)
+             val_loss = self.evaluate(model, dataloader, criterion)
+             scheduler.step(val_loss)  # Update scheduler with validation loss
+
+             image = self.test_inference(model, "A house next to a river.")
+             self.visualize_image(image, f'generated_image_epoch_{epoch + 1}.png')
+             self.terminal_log(f'Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item()}, Validation Loss: {val_loss}')
+             StorageHandler().save_checkpoint(model, optimizer, scheduler, epoch + 1, checkpoint_path)
+
+         self.terminal_log("Training completed.")
+
+     def evaluate(self, model, dataloader, criterion):
+         model.eval()
+         total_loss = 0
+         with torch.no_grad():
+             for i, (text_batch, image_batch) in enumerate(dataloader):
+                 image_batch = image_batch.to(self.device)
+                 images = model(text_batch)
+                 loss = criterion(images, image_batch)
+                 total_loss += loss.item()
+
+         avg_loss = total_loss / len(dataloader)
+         self.terminal_log(f'Validation Loss: {avg_loss}')
+         return avg_loss
+
+     def test_inference(self, model, text):
+         model.eval()
+         with torch.no_grad():
+             if isinstance(text, str):
+                 generated_image = model(text)
+             else:
+                 generated_image = [model(t) for t in text]
+         return generated_image
+
+     def visualize_image(self, image_tensor, filename='generated_image.png'):
+         image_tensor = image_tensor.squeeze(0).cpu().detach()
+         image_tensor = (image_tensor - image_tensor.min()) / (image_tensor.max() - image_tensor.min())  # Normalize to [0, 1]
+         image_tensor = image_tensor.permute(1, 2, 0)  # Change from (C, H, W) to (H, W, C)
+         plt.imshow(image_tensor)
+         plt.savefig(filename)
+         # plt.show()
+         plt.close()  # Release the figure so repeated calls don't accumulate
+
+ if __name__ == "__main__":
+     batch_size = 1
+     learning_rate = 1e-4
+     num_epochs = 500
+     encoder = "google-t5/t5-small"
+     checkpoint_path = './models/image_gen'
+     os.makedirs(checkpoint_path, exist_ok=True)
+     checkpoint_file = f"{checkpoint_path}/checkpoint_backup.pth"
+     use_huggingface_dataset = False  # <-- Toggle between datasets
+     limit_huggingface_dataset = 1000  # <-- Limit for the HuggingFace dataset
+     train_img_dir = './train_images'
+     annotations_file = f'{train_img_dir}/annotations.txt'
+     storage_dir = "./image_gen_storage"
+     os.makedirs(storage_dir, exist_ok=True)
+     hub_model_name = "K00B404/tiny_image_gen"
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     text_encoder = TextEncoder(encoder, encoder_type="t5", device=device)
+     conditional_diffusion_model = ConditionalDiffusionModel()
+     super_resolution_diffusion_model = SuperResolutionDiffusionModel()
+     text_to_image_model = TextToImageModel(text_encoder, conditional_diffusion_model, super_resolution_diffusion_model)
+     text_to_image_model.to(device)
+
+     optimizer = optim.AdamW(text_to_image_model.parameters(), lr=learning_rate)
+     criterion = nn.MSELoss()
+     # factor=0.001 cuts the LR by 1000x on plateau, an unusually aggressive setting
+     scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.001, patience=2)
+
+     common = Common(device, True)  # wandb_log=True
+     storage = StorageHandler(checkpoint_path, hub_model_name)
+
+     start_epoch, scheduler = storage.load_checkpoint(text_to_image_model, optimizer, scheduler, checkpoint_file)
+
+     if use_huggingface_dataset:
+         dataset = HuggingDataset(size_sqr=128, limit=limit_huggingface_dataset)
+     else:
+         dataset = CustomDataset(annotations_file, train_img_dir, size_sqr=128)
+
+     dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
+
+     common.train(text_to_image_model, optimizer, scheduler, dataloader, criterion, num_epochs, start_epoch, checkpoint_file)
+
+     sample_texts = ["A big ape.", "A yellow banana."]
+     for idx, sample_text in enumerate(sample_texts):
+         generated_image = common.test_inference(text_to_image_model, sample_text)
+         # Use a distinct filename per prompt so outputs are not overwritten
+         common.visualize_image(generated_image, f'generated_sample_{idx}.png')